import torch.nn as nn


class DecoderBase(nn.Module):
    r"""
    Abstract root class shared by every decoder.

    The constructor only records the attention-related options on the
    instance. ``forward``, ``decode_step`` and ``get_decoder_init_state`` are
    placeholders that raise ``NotImplementedError`` until a subclass
    overrides them.

    Parameters
    ----------
    use_attention: bool, default=True
        Whether to use attention mechanism.
    use_copy: bool, default=False
        Whether to use copy attention (pointer network) mechanism.
    use_coverage: bool, default=False
        Whether to use coverage attention mechanism.
    attention_type: str, default="uniform"
        The attention strategy choice. The value is stored as given; this
        class does not validate it.
        "``uniform``": uniform attention. We will attend on the nodes uniformly.
        "``sep_diff_encoder_type``": separate attention.
                                    We will attend on different encoder results separately.
        "``sep_diff_node_type``": separate attention.
                                    We will attend on different node type separately.
    fuse_strategy: str, option=[None, "average", "concatenate"], default="average"
        The strategy to fuse attention results generated by separate attention.
        The value is stored as given; this class does not validate it.
        "None": intended for ``uniform`` attention, where there is nothing to fuse.
        "``average``": We will take an average on all results.
        "``concatenate``": We will concatenate all results to one.
    """

    def __init__(
        self,
        use_attention=True,
        use_copy=False,
        use_coverage=False,
        attention_type="uniform",
        fuse_strategy="average",
    ):
        # nn.Module bookkeeping has to be in place before attributes are set.
        super().__init__()

        # How attention is applied and how separate results are merged.
        self.attention_type = attention_type
        self.fuse_strategy = fuse_strategy

        # Feature switches.
        self.use_attention = use_attention
        self.use_copy = use_copy
        self.use_coverage = use_coverage

    def forward(self, **kwargs):
        r"""
        Run the whole decoding computation. Abstract: subclasses override it.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def decode_step(self, **kwargs):
        r"""
        Run a single decoding step. Abstract: subclasses override it.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def get_decoder_init_state(self, **kwargs):
        r"""
        Build the state decoding starts from. Abstract: subclasses override it.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError


class RNNDecoderBase(DecoderBase):
    r"""
    The base class for RNN (sequence) decoders.

    It forwards the attention options to ``DecoderBase`` unchanged and
    declares the hooks a concrete RNN decoder has to implement. The only
    concrete logic is ``forward``, which chains ``extract_params`` and
    ``_run_forward_pass``.

    Parameters
    ----------
    use_attention: bool, default=True
        Whether to use attention mechanism.
    use_copy: bool, default=False
        Whether to use copy attention (pointer network) mechanism.
    use_coverage: bool, default=False
        Whether to use coverage attention mechanism.
    attention_type: str, default="uniform"
        The attention strategy choice (stored as given, not validated here).
    fuse_strategy: str, default="average"
        The strategy to fuse attention results generated by separate
        attention (stored as given, not validated here).
    """

    def __init__(
        self,
        use_attention=True,
        use_copy=False,
        use_coverage=False,
        attention_type="uniform",
        fuse_strategy="average",
    ):
        super(RNNDecoderBase, self).__init__(
            use_attention=use_attention,
            use_copy=use_copy,
            use_coverage=use_coverage,
            attention_type=attention_type,
            fuse_strategy=fuse_strategy,
        )

    def _build_rnn(self, **kwargs):
        """
        The private function to design the specific RNN models.
        MUST be overridden by subclass.
        """
        raise NotImplementedError()

    def extract_params(self, g):
        """
        Turn the decoder input ``g`` into keyword arguments for
        ``_run_forward_pass``. MUST be overridden by subclass.

        Parameters
        ----------
        g:
            The input handed to ``forward``. Its type is not constrained by
            this class (presumably a graph/batch object -- confirm against
            the concrete decoder).

        Returns
        -------
        params: dict
            ``forward`` unpacks the result with ``**``, so it has to be a
            mapping whose keys are parameter names of ``_run_forward_pass``.
        """
        raise NotImplementedError()

    def _extract_params(self, g):
        """
        Private spelling of ``extract_params`` used by ``forward``.

        ``forward`` has always called ``_extract_params`` while the declared
        hook is ``extract_params``. Delegating here makes both spellings
        work: subclasses may override either one.
        """
        return self.extract_params(g)

    def forward(self, g, tgt_seq=None, src_seq=None):
        """
        Extract the encoder results from ``g`` and run the decoder on them.

        Parameters
        ----------
        g:
            The input passed on to ``extract_params``.
        tgt_seq: default=None
            The target sequence, passed on to ``_run_forward_pass``.
        src_seq: default=None
            Accepted for interface compatibility; this base implementation
            does not use it.

        Returns
        -------
        Whatever ``_run_forward_pass`` returns.
        """
        params = self._extract_params(g)
        return self._run_forward_pass(**params, tgt_seq=tgt_seq)

    def _run_forward_pass(
        self,
        graph_node_embedding,
        graph_node_mask,
        rnn_node_embedding,
        graph_level_embedding,
        graph_edge_embedding=None,
        graph_edge_mask=None,
        tgt_seq=None,
    ):
        r"""
        The private calculation method for decoder.
        MUST be overridden by subclass.

        Parameters
        ----------
        tgt_seq: torch.Tensor,
            The target sequence of shape :math:`(B, seq_len)`, where :math:`B`
            is size of batch, :math:`seq_len` is the length of sequences.
            Note that it consists of tokens' indices.
        graph_node_embedding: torch.Tensor,
            The graph node embedding matrix of shape :math:`(B, N, D_{in})`
        graph_node_mask: torch.Tensor,
            The graph node type mask matrix of shape :math:`(B, N)`
        rnn_node_embedding: torch.Tensor,
            The rnn encoded embedding matrix of shape :math:`(B, N, D_{in})`
        graph_level_embedding: torch.Tensor,
            graph level embedding of shape :math:`(B, D_{in})`
        graph_edge_embedding: torch.Tensor,
            graph edge embedding of shape :math:`(B, N, D_{in})`
        graph_edge_mask: torch.Tensor,
            graph edge type embedding

        Returns
        -------
        logits: torch.Tensor
        attns: torch.Tensor
        """
        raise NotImplementedError()

    def decode_step(
        self,
        decoder_input,
        rnn_state,
        encoder_out,
        dec_input_mask,
        rnn_emb=None,
        enc_attn_weights_average=None,
        src_seq=None,
        oov_dict=None,
    ):
        r"""
        One step for decoding. MUST be overridden by subclass.

        Parameters
        ----------
        decoder_input: torch.Tensor
            The input for current decoding step
        rnn_state: torch.Tensor
            Rnn_state
        encoder_out: torch.Tensor
            The graph node embedding for decoding
        dec_input_mask: torch.Tensor
            The mask of graph node.
            Notes: ``-1`` is the dummy node, each int larger than -1 is one
            class for separate attention.
        rnn_emb: torch.Tensor
            The graph node embedding from RNN encoder.
        enc_attn_weights_average: list
            The list of encoder attention weights. It will be used for coverage.
        src_seq: torch.Tensor
            The source sequence. It will be used for copy.
        oov_dict: Vocab
            The vocabulary containing out-of-vocabulary words.
        """
        raise NotImplementedError()

    def get_decoder_init_state(self, rnn_type, batch_size, content=None):
        r"""
        The initial state for RNN decoder. MUST be overridden by subclass.

        Parameters
        ----------
        rnn_type: str, option=["LSTM", "GRU"]
            The rnn type.
        batch_size: int
            The batch size of the initial state.
        content: torch.Tensor, default=None
            The initialization of initial state.

        Returns
        -------
        initial_state: Any
        """
        raise NotImplementedError()


class RNNTreeDecoderBase(DecoderBase):
    r"""
    The base class for rnn tree decoder.

    It forwards the attention options to ``DecoderBase`` and declares the
    hooks a concrete tree decoder has to implement. The only concrete logic
    is ``forward``, which unpacks the encoder output dictionary into
    ``_run_forward_pass``.

    Parameters
    ----------
    use_attention: bool, default=True
        Whether to use attention mechanism
    use_copy: bool, default=False
        Whether to use copy attention(pointer network) mechanism
    use_coverage: bool, default=False
        Whether to use coverage attention mechanism
    attention_type: str, default="uniform"
        The attention strategy choice. The value is stored as given; this
        class does not validate it.
        "``uniform``": uniform attention. We will attend on the nodes uniformly.
        "``separate_different_encoder_type``": separate attention.
            We will attend on different encoder results separately.
        "``separate_different_node_type``": separate attention.
            We will attend on different node type separately.
    fuse_strategy: str, option=[None, "average", "concatenate"], default="average"
        The strategy to fuse attention results generated by separate attention.
        "None": intended for ``uniform`` attention, where there is nothing to fuse.
        "``average``": We will take an average on all results.
        "``concatenate``": We will concatenate all results to one.
    """

    def __init__(
        self,
        use_attention=True,
        use_copy=False,
        use_coverage=False,
        attention_type="uniform",
        fuse_strategy="average",
    ):
        # Forward every option. Passing only ``use_attention`` left the other
        # four attributes stuck at the DecoderBase defaults, whatever the
        # caller asked for.
        super(RNNTreeDecoderBase, self).__init__(
            use_attention=use_attention,
            use_copy=use_copy,
            use_coverage=use_coverage,
            attention_type=attention_type,
            fuse_strategy=fuse_strategy,
        )

    def _build_rnn(self, rnn_type, **kwargs):
        """
        The private function to design the specific RNN models.
        MUST be overridden by subclass.
        """
        raise NotImplementedError()

    def forward(self, encoder_output_dict, tgt_tree_batch, enc_batch):
        """
        Unpack the encoder results and run the decoder on them.

        Parameters
        ----------
        encoder_output_dict: dict
            Unpacked with ``**`` into ``_run_forward_pass``, so its keys have
            to be parameter names of that method.
        tgt_tree_batch:
            The target tree batch, passed on as ``tgt_tree_batch``.
        enc_batch:
            Passed on as ``enc_batch``.

        Returns
        -------
        Whatever ``_run_forward_pass`` returns.
        """
        return self._run_forward_pass(
            **encoder_output_dict, tgt_tree_batch=tgt_tree_batch, enc_batch=enc_batch
        )

    def _run_forward_pass(
        self,
        graph_node_embedding,
        graph_node_mask,
        rnn_node_embedding,
        graph_level_embedding,
        graph_edge_embedding=None,
        graph_edge_mask=None,
        tgt=None,
        enc_batch=None,
        tgt_tree_batch=None,
    ):
        r"""
        The private calculation method for decoder.
        MUST be overridden by subclass.

        Parameters
        ----------
        tgt_tree_batch: default=None
            The target tree batch, under the keyword ``forward`` uses.
            TODO: add tree decode batch data description
        tgt: default=None
            Earlier name of the target argument, kept so existing callers
            that pass ``tgt=`` keep working.
        enc_batch: default=None
            The ``enc_batch`` argument received by ``forward``.
        graph_node_embedding: torch.Tensor,
            The graph node embedding matrix of shape :math:`(B, N, D_{in})`
        graph_node_mask: torch.Tensor,
            The graph node type mask matrix of shape :math:`(B, N)`
        rnn_node_embedding: torch.Tensor,
            The rnn encoded embedding matrix of shape :math:`(B, N, D_{in})`
        graph_level_embedding: torch.Tensor,
            graph level embedding of shape :math:`(B, D_{in})`
        graph_edge_embedding: torch.Tensor,
            graph edge embedding of shape :math:`(B, N, D_{in})`
        graph_edge_mask: torch.Tensor,
            graph edge type embedding

        Returns
        -------
        logits: torch.Tensor
        attns: torch.Tensor
        """
        raise NotImplementedError()
